{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from absl import logging\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow.compat.v1 as tf\n",
    "\n",
    "from open_spiel.python import policy\n",
    "from open_spiel.python import rl_environment\n",
    "from open_spiel.python.algorithms import exploitability\n",
    "from open_spiel.python.algorithms import nfsp\n",
    "from open_spiel.python.pytorch import nfsp as nfsp_pt\n",
    "\n",
    "class NFSPPolicies(policy.Policy):\n",
    "  \"\"\"Joint policy over per-player NFSP agents, used for evaluation.\n",
    "\n",
    "  Each query builds an observation for the player to move and asks that\n",
    "  player's agent for action probabilities while temporarily in `mode`.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, env, nfsp_policies, mode):\n",
    "    \"\"\"Initializes the joint policy.\n",
    "\n",
    "    Args:\n",
    "      env: environment whose `game` this policy is defined over.\n",
    "      nfsp_policies: agents indexed by player id.\n",
    "      mode: agent mode applied via `temp_mode_as` when querying an agent.\n",
    "    \"\"\"\n",
    "    game = env.game\n",
    "    # Derive the player set from the game instead of assuming two players.\n",
    "    num_players = game.num_players()\n",
    "    player_ids = list(range(num_players))\n",
    "    super(NFSPPolicies, self).__init__(game, player_ids)\n",
    "    self._policies = nfsp_policies\n",
    "    self._mode = mode\n",
    "    # Observation buffer reused across queries; only the current player's\n",
    "    # slots are refreshed on each call.\n",
    "    self._obs = {\n",
    "        \"info_state\": [None] * num_players,\n",
    "        \"legal_actions\": [None] * num_players,\n",
    "    }\n",
    "\n",
    "  def action_probabilities(self, state, player_id=None):\n",
    "    \"\"\"Returns a dict mapping each legal action at `state` to its probability.\n",
    "\n",
    "    `player_id` is ignored; the acting player is `state.current_player()`.\n",
    "    \"\"\"\n",
    "    cur_player = state.current_player()\n",
    "    legal_actions = state.legal_actions(cur_player)\n",
    "\n",
    "    self._obs[\"current_player\"] = cur_player\n",
    "    self._obs[\"info_state\"][cur_player] = (\n",
    "        state.information_state_tensor(cur_player))\n",
    "    self._obs[\"legal_actions\"][cur_player] = legal_actions\n",
    "\n",
    "    info_state = rl_environment.TimeStep(\n",
    "        observations=self._obs, rewards=None, discounts=None, step_type=None)\n",
    "\n",
    "    agent = self._policies[cur_player]\n",
    "    with agent.temp_mode_as(self._mode):\n",
    "      probs = agent.step(info_state, is_evaluation=True).probs\n",
    "    return {action: probs[action] for action in legal_actions}\n",
    "\n",
    "\n",
    "def tf_main(game,\n",
    "            env_config,\n",
    "            num_train_episodes,\n",
    "            eval_every,\n",
    "            hidden_layers_sizes,\n",
    "            replay_buffer_capacity,\n",
    "            reservoir_buffer_capacity,\n",
    "            anticipatory_param):\n",
    "  \"\"\"Trains TensorFlow NFSP agents by self-play and tracks exploitability.\n",
    "\n",
    "  Args:\n",
    "    game: game name passed to `rl_environment.Environment`.\n",
    "    env_config: dict of keyword arguments for the environment.\n",
    "    num_train_episodes: number of self-play training episodes.\n",
    "    eval_every: exploitability is evaluated whenever `ep + 1` is a multiple\n",
    "      of this value.\n",
    "    hidden_layers_sizes: hidden layer sizes (each coerced to int).\n",
    "    replay_buffer_capacity: per-agent replay buffer capacity.\n",
    "    reservoir_buffer_capacity: per-agent reservoir buffer capacity.\n",
    "    anticipatory_param: NFSP anticipatory parameter.\n",
    "\n",
    "  Returns:\n",
    "    List of exploitability values of the joint average policy, one per\n",
    "    evaluation.\n",
    "  \"\"\"\n",
    "  # Use the config passed in (not a notebook global) and derive the player\n",
    "  # count from the game.\n",
    "  env = rl_environment.Environment(game, **env_config)\n",
    "  info_state_size = env.observation_spec()[\"info_state\"][0]\n",
    "  num_actions = env.action_spec()[\"num_actions\"]\n",
    "  num_players = env.game.num_players()\n",
    "\n",
    "  hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n",
    "  kwargs = {\n",
    "      \"replay_buffer_capacity\": replay_buffer_capacity,\n",
    "      \"epsilon_decay_duration\": num_train_episodes,\n",
    "      \"epsilon_start\": 0.06,\n",
    "      \"epsilon_end\": 0.001,\n",
    "  }\n",
    "  expl_list = []\n",
    "  # A fresh graph per call keeps repeated runs in one kernel (e.g. Kuhn, then\n",
    "  # Leduc) from sharing the default graph, and builds the model in graph mode\n",
    "  # even when eager execution is enabled.\n",
    "  with tf.Graph().as_default(), tf.Session() as sess:\n",
    "    # pylint: disable=g-complex-comprehension\n",
    "    agents = [\n",
    "        nfsp.NFSP(sess, idx, info_state_size, num_actions, hidden_layers_sizes,\n",
    "                  reservoir_buffer_capacity, anticipatory_param,\n",
    "                  **kwargs) for idx in range(num_players)\n",
    "    ]\n",
    "    expl_policies_avg = NFSPPolicies(env, agents, nfsp.MODE.average_policy)\n",
    "\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    for ep in range(num_train_episodes):\n",
    "      if (ep + 1) % eval_every == 0:\n",
    "        losses = [agent.loss for agent in agents]\n",
    "        print(\"Losses: %s\" % losses)\n",
    "        expl = exploitability.exploitability(env.game, expl_policies_avg)\n",
    "        expl_list.append(expl)\n",
    "        print(\"[%s] Exploitability AVG %s\" % (ep + 1, expl))\n",
    "        print(\"_____________________________________________\")\n",
    "\n",
    "      time_step = env.reset()\n",
    "      while not time_step.last():\n",
    "        player_id = time_step.observations[\"current_player\"]\n",
    "        agent_output = agents[player_id].step(time_step)\n",
    "        time_step = env.step([agent_output.action])\n",
    "\n",
    "      # Episode is over, step all agents with final info state.\n",
    "      for agent in agents:\n",
    "        agent.step(time_step)\n",
    "  return expl_list\n",
    "        \n",
    "def pt_main(game,\n",
    "            env_config,\n",
    "            num_train_episodes,\n",
    "            eval_every,\n",
    "            hidden_layers_sizes,\n",
    "            replay_buffer_capacity,\n",
    "            reservoir_buffer_capacity,\n",
    "            anticipatory_param):\n",
    "  \"\"\"Trains PyTorch NFSP agents by self-play and tracks exploitability.\n",
    "\n",
    "  Args:\n",
    "    game: game name passed to `rl_environment.Environment`.\n",
    "    env_config: dict of keyword arguments for the environment.\n",
    "    num_train_episodes: number of self-play training episodes.\n",
    "    eval_every: exploitability is evaluated whenever `ep + 1` is a multiple\n",
    "      of this value.\n",
    "    hidden_layers_sizes: hidden layer sizes (each coerced to int).\n",
    "    replay_buffer_capacity: per-agent replay buffer capacity.\n",
    "    reservoir_buffer_capacity: per-agent reservoir buffer capacity.\n",
    "    anticipatory_param: NFSP anticipatory parameter.\n",
    "\n",
    "  Returns:\n",
    "    List of exploitability values of the joint average policy, one per\n",
    "    evaluation.\n",
    "  \"\"\"\n",
    "  # Use the config passed in (not a notebook global) and derive the player\n",
    "  # count from the game.\n",
    "  env = rl_environment.Environment(game, **env_config)\n",
    "  info_state_size = env.observation_spec()[\"info_state\"][0]\n",
    "  num_actions = env.action_spec()[\"num_actions\"]\n",
    "  num_players = env.game.num_players()\n",
    "\n",
    "  hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n",
    "  kwargs = {\n",
    "      \"replay_buffer_capacity\": replay_buffer_capacity,\n",
    "      \"epsilon_decay_duration\": num_train_episodes,\n",
    "      \"epsilon_start\": 0.06,\n",
    "      \"epsilon_end\": 0.001,\n",
    "  }\n",
    "  expl_list = []\n",
    "  # pylint: disable=g-complex-comprehension\n",
    "  agents = [\n",
    "      nfsp_pt.NFSP(idx, info_state_size, num_actions, hidden_layers_sizes,\n",
    "                   reservoir_buffer_capacity, anticipatory_param,\n",
    "                   **kwargs) for idx in range(num_players)\n",
    "  ]\n",
    "  expl_policies_avg = NFSPPolicies(env, agents, nfsp_pt.MODE.average_policy)\n",
    "\n",
    "  for ep in range(num_train_episodes):\n",
    "    if (ep + 1) % eval_every == 0:\n",
    "      # NOTE(review): assumes this nfsp_pt version exposes `loss` as a scalar\n",
    "      # tensor (hence `.item()`); confirm against the installed OpenSpiel.\n",
    "      losses = [agent.loss.item() for agent in agents]\n",
    "      print(\"Losses: %s\" % losses)\n",
    "      expl = exploitability.exploitability(env.game, expl_policies_avg)\n",
    "      expl_list.append(expl)\n",
    "      print(\"[%s] Exploitability AVG %s\" % (ep + 1, expl))\n",
    "      print(\"_____________________________________________\")\n",
    "\n",
    "    time_step = env.reset()\n",
    "    while not time_step.last():\n",
    "      player_id = time_step.observations[\"current_player\"]\n",
    "      agent_output = agents[player_id].step(time_step)\n",
    "      time_step = env.step([agent_output.action])\n",
    "\n",
    "    # Episode is over, step all agents with final info state.\n",
    "    for agent in agents:\n",
    "      agent.step(time_step)\n",
    "  return expl_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kuhn poker experiment configuration.\n",
    "game = \"kuhn_poker\"\n",
    "num_players = 2\n",
    "env_configs = {\"players\": num_players}\n",
    "\n",
    "# Training length and how often (in episodes) exploitability is evaluated.\n",
    "num_train_episodes = 3_000_000\n",
    "eval_every = 10_000\n",
    "\n",
    "# NFSP hyperparameters.\n",
    "hidden_layers_sizes = [128]\n",
    "replay_buffer_capacity = 200_000\n",
    "reservoir_buffer_capacity = 2_000_000\n",
    "anticipatory_param = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_kuhn_result = tf_main(\n",
    "    game=game,\n",
    "    env_config=env_configs,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    reservoir_buffer_capacity=reservoir_buffer_capacity,\n",
    "    anticipatory_param=anticipatory_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pt_kuhn_result = pt_main(\n",
    "    game=game,\n",
    "    env_config=env_configs,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    reservoir_buffer_capacity=reservoir_buffer_capacity,\n",
    "    anticipatory_param=anticipatory_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The i-th exploitability value was recorded when (episode + 1) reached\n",
    "# (i + 1) * eval_every, so scale the x axis by the Kuhn eval_every.\n",
    "x_tf = [(i + 1) * eval_every for i in range(len(tf_kuhn_result))]\n",
    "x_pt = [(i + 1) * eval_every for i in range(len(pt_kuhn_result))]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x_tf, tf_kuhn_result, label='tensorflow')\n",
    "ax.plot(x_pt, pt_kuhn_result, label='pytorch')\n",
    "ax.set_title('Kuhn Poker')\n",
    "ax.set_xlabel('Episodes')\n",
    "ax.set_ylabel('Exploitability')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leduc poker experiment configuration.\n",
    "game = \"leduc_poker\"\n",
    "num_players = 2\n",
    "env_configs = {\"players\": num_players}\n",
    "\n",
    "# Training length and how often (in episodes) exploitability is evaluated.\n",
    "num_train_episodes = 3_000_000\n",
    "eval_every = 100_000\n",
    "\n",
    "# NFSP hyperparameters.\n",
    "hidden_layers_sizes = [128]\n",
    "replay_buffer_capacity = 200_000\n",
    "reservoir_buffer_capacity = 2_000_000\n",
    "anticipatory_param = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_leduc_result = tf_main(\n",
    "    game=game,\n",
    "    env_config=env_configs,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    reservoir_buffer_capacity=reservoir_buffer_capacity,\n",
    "    anticipatory_param=anticipatory_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pt_leduc_result = pt_main(\n",
    "    game=game,\n",
    "    env_config=env_configs,\n",
    "    num_train_episodes=num_train_episodes,\n",
    "    eval_every=eval_every,\n",
    "    hidden_layers_sizes=hidden_layers_sizes,\n",
    "    replay_buffer_capacity=replay_buffer_capacity,\n",
    "    reservoir_buffer_capacity=reservoir_buffer_capacity,\n",
    "    anticipatory_param=anticipatory_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The i-th exploitability value was recorded when (episode + 1) reached\n",
    "# (i + 1) * eval_every, so scale the x axis by the Leduc eval_every.\n",
    "x_tf = [(i + 1) * eval_every for i in range(len(tf_leduc_result))]\n",
    "x_pt = [(i + 1) * eval_every for i in range(len(pt_leduc_result))]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x_tf, tf_leduc_result, label='tensorflow')\n",
    "ax.plot(x_pt, pt_leduc_result, label='pytorch')\n",
    "ax.set_title('Leduc Poker')\n",
    "ax.set_xlabel('Episodes')\n",
    "ax.set_ylabel('Exploitability')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
